import os, sys, re, time, copy, pickle, glob
import numpy as np
from tqdm import tqdm

import torch
from torch.optim import AdamW
import torch.nn.functional as F

from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer
from load_data import *

class Retriever():
    """
    Dense retriever over a prototype passage corpus.

    Passages come from ``load_prototype_list``. They are embedded once with a
    SentenceTransformer encoder, and the embeddings are cached to a ``.pt``
    file under ``<dataset_folder>/<corpus_name>/`` so later runs reuse them.
    Queries are scored against the corpus by a plain inner product.
    """

    def __init__(self,
                 dataset_folder="/data/shared/zhexu/LLM_textual",
                 corpus_name='ogbn-arxiv',
                 encoder_name='sentence-transformers/sentence-t5-base',
                 lr=1e-5,
                 seed=0,
                 device='cuda'):
        """
        Load the corpus and build the encoder, optimizer and corpus embeddings.

        Args:
            dataset_folder: Root data folder. It is passed to
                ``load_prototype_list``, and the embedding cache is stored in
                its ``<corpus_name>`` sub-folder.
            corpus_name: Dataset name. It is forwarded to
                ``load_prototype_list`` and used in the cache filename.
            encoder_name: SentenceTransformer model name or path.
            lr: Learning rate of the AdamW optimizer over the encoder weights.
            seed: Seed forwarded to ``load_prototype_list``. It is also part
                of the cache filename, so different seeds do not share
                embeddings.
            device: Torch device for the encoder and the corpus embeddings.
        """
        self.dataset_folder = dataset_folder
        self.device = device

        self.passage_list = load_prototype_list(self.dataset_folder, corpus_name, seed=seed)

        self.model = SentenceTransformer(encoder_name).to(device)
        print(f"RAG encoder: {encoder_name}")
        self.optimizer = AdamW(self.model.parameters(), lr=lr)

        total_parameters = sum(p.numel() for p in self.model.parameters())
        print(f"Total number of parameters of retriever: {total_parameters}")

        # Drop the common hub prefix first, so existing cache names stay the
        # same. Then flatten any remaining '/' so names like 'org/model' do
        # not point into a sub-directory that does not exist.
        tmp_encoder_name = encoder_name.replace('sentence-transformers/', '').replace('/', '_')
        corpus_dir = f"{self.dataset_folder}/{corpus_name}"
        feature_file = f"{corpus_dir}/{corpus_name}_{seed}_{tmp_encoder_name}_emb.pt"
        if os.path.exists(feature_file):
            print("Loading from existing file...")
            # map_location lets a cache written on one device (e.g. CUDA) be
            # loaded on another (e.g. a CPU-only machine).
            self.corpus_emb = torch.load(feature_file, map_location=device)
        else:
            print("Generating new corpus embeddings...")
            self.corpus_emb = self.model.encode(
                self.passage_list, convert_to_tensor=True, show_progress_bar=True
            ).to(device)
            os.makedirs(corpus_dir, exist_ok=True)
            torch.save(self.corpus_emb, feature_file)

    def retrieve_no_grad(self, queries, topk=3):
        """
        Encode ``queries`` and return their top-k corpus matches without autograd.

        Args:
            queries: Query text(s) accepted by the encoder's tokenizer
                (presumably a list of strings).
            topk: Number of passages to return per query. It is clamped to
                the corpus size.

        Returns:
            ``(top_k_scores, top_k_indices)``, each of shape
            ``(num_queries, min(topk, num_passages))``.
        """
        self.model.eval()
        # Inference only: do not record an autograd graph for the encoder.
        with torch.no_grad():
            encoding = self.model.tokenizer(
                queries, padding=True, truncation=True, return_tensors='pt'
            ).to(self.device)
            queries_emb = self.model(encoding)['sentence_embedding']
            return self.retrieve(queries_emb, topk=topk)

    def retrieve(self, queries_emb, topk=3):
        """
        Score pre-computed query embeddings against the corpus by inner product.

        Args:
            queries_emb: 2-D tensor ``(num_queries, dim)`` on the same device
                as ``self.corpus_emb``.
            topk: Number of passages to return per query. It is clamped to
                the corpus size.

        Returns:
            ``(top_k_scores, top_k_indices)``, each of shape
            ``(num_queries, min(topk, num_passages))``.
        """
        # TODO: plain inner product for now; a different similarity may be
        # used later.
        scores = torch.matmul(queries_emb, self.corpus_emb.T)
        # torch.topk raises when k exceeds the number of passages.
        k = min(topk, scores.shape[-1])
        top_k_scores, top_k_indices = torch.topk(scores, k, dim=1)
        return top_k_scores, top_k_indices

    def train(self, queries, topk=3, ground_truth=None):
        """
        Deprecated. Run one optimizer step on a KL loss between scores and ``ground_truth``.

        Args:
            queries: Query text(s) accepted by the encoder's tokenizer.
            topk: Unused; kept for signature compatibility.
            ground_truth: Target weights over the corpus, shape
                ``(num_queries, num_passages)``. It is presumably a
                probability distribution per row. It is moved to the
                encoder's device.

        Raises:
            ValueError: If ``ground_truth`` is not given.
        """
        # Deprecated
        if ground_truth is None:
            raise ValueError("ground_truth must be provided")
        self.model.train()

        # The model lives on self.device; the tokenizer returns CPU tensors.
        encoding = self.model.tokenizer(
            queries, padding=True, truncation=True, return_tensors='pt'
        ).to(self.device)
        queries_emb = self.model(encoding)['sentence_embedding']

        scores = torch.matmul(queries_emb, self.corpus_emb.T)
        # NOTE(review): inner products can be negative, so this is not
        # guaranteed to be a probability distribution. The KL below can then
        # produce NaN.
        normalized_scores = scores / scores.sum(dim=1, keepdim=True)

        target = ground_truth.to(normalized_scores.device)
        log_ground_truth = torch.log(target + 1e-10)
        # F.kl_div(input, target) treats `input` as log-probabilities. This
        # call therefore computes KL(normalized_scores || ground_truth), with
        # gradients flowing through the `target` argument.
        cross_entropy_loss = F.kl_div(log_ground_truth, normalized_scores, reduction='batchmean')
        cross_entropy_loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        print(ground_truth)
        print(normalized_scores)
        print(f"Training loss: {cross_entropy_loss.item()}")
        print()

if __name__ == '__main__':
    # Smoke run of the (deprecated) Retriever.train loop. It repeatedly
    # pushes the score distribution of one query toward a uniform target.

    encoder_name = 'sentence-transformers/sentence-t5-base'
    # encoder_name = 'sentence-transformers/all-roberta-large-v1'
    # encoder_name = 'msmarco-bert-base-dot-v5'

    lr = 1e-3
    dataset_name = 'ogbn-arxiv'
    dataset_folder = "/data/shared/zhexu/LLM_textual/"
    topk = 3

    retriever = Retriever(corpus_name=dataset_name, encoder_name=encoder_name, lr=lr)

    title_list = load_title_list(dataset_folder, dataset_name)
    content_list = load_content_list(dataset_folder, dataset_name)
    passage_list = [title + '\n' + content for title, content in zip(title_list, content_list)]

    retrieval_list = passage_list[0:1]

    # Uniform target over the retriever's own corpus (the prototype list).
    # Its size must match retriever.corpus_emb. The previously hard-coded
    # 169343 only worked if the prototype list had exactly that many entries.
    num_passages = retriever.corpus_emb.shape[0]
    uniform_target = torch.full((len(retrieval_list), num_passages), 1.0 / num_passages)

    for _ in range(100):
        retriever.train(retrieval_list, topk=topk, ground_truth=uniform_target)

    # To inspect retrieval results instead (indices refer to retriever.passage_list):
    # top_k_scores, top_k_indices = retriever.retrieve_no_grad(retrieval_list, topk=topk)
    # for i in range(len(retrieval_list)):
    #     print(f"Query: {retrieval_list[i]}")
    #     for j in range(top_k_indices.shape[1]):
    #         print(f"Top {j+1} passage: {retriever.passage_list[top_k_indices[i][j]]}")
    #         print(f"Top {j+1} score: {top_k_scores[i][j]}")
    #         print("")
